{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "outputs": [],
   "source": [
    "from torchvision import datasets,transforms\n",
    "\n",
    "data_path =\"../data\"\n",
    "digit_train = datasets.MNIST(data_path,train=True,download=False,transform=transforms.ToTensor())\n",
    "digit_test = datasets.MNIST(data_path,train=False,download=False,transform=transforms.ToTensor())"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-10-20T12:10:23.423701100Z",
     "start_time": "2023-10-20T12:09:52.864803900Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "outputs": [
    {
     "data": {
      "text/plain": "(torch.Size([1, 28, 28, 10000]), torch.Size([1, 28, 28, 60000]))"
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "digit_value_train = torch.stack([digit for digit,_ in digit_train],dim=3)\n",
    "digit_value_test = torch.stack([digit for digit,_ in digit_test],dim=3)\n",
    "digit_value_test.shape,digit_value_train.shape"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-10-20T12:11:00.448050Z",
     "start_time": "2023-10-20T12:10:49.664052500Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "测试集均值 tensor(0.1325)\n",
      "训练集均值 tensor(0.1307)\n"
     ]
    }
   ],
   "source": [
    "# Global pixel mean of each split; the printed captions read 'test set mean' / 'train set mean'.\n",
    "for caption, pixels in (('测试集均值', digit_value_test), ('训练集均值', digit_value_train)):\n",
    "    print(caption, torch.mean(pixels))"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-10-20T12:11:03.924364300Z",
     "start_time": "2023-10-20T12:11:03.829371600Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "测试集标准差 tensor(0.3105)\n",
      "训练集标准差 tensor(0.3081)\n"
     ]
    }
   ],
   "source": [
    "# Global pixel standard deviation of each split; captions read 'test set std' / 'train set std'.\n",
    "for caption, pixels in (('测试集标准差', digit_value_test), ('训练集标准差', digit_value_train)):\n",
    "    print(caption, torch.std(pixels))"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-10-20T12:11:04.912370700Z",
     "start_time": "2023-10-20T12:11:04.823376Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "outputs": [
    {
     "data": {
      "text/plain": "<Figure size 640x480 with 1 Axes>",
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaEAAAGdCAYAAAC7EMwUAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAa8UlEQVR4nO3dcWyU953n8c/EwIRw42ko2DMOxvK20KSYow0hgEXAcMWLu0EhpCeSbCuzalHSGFTqZKNS/sAX6XCOHixtaeglWxG4hIJOm5B0QSHOgU05QtZQZ8PSiDOHKW6xz8JLPMYhA4bf/cEx2wEC+U1m+DLj90saKZ55vjw/nj7lzcOMHwecc04AABi4zXoBAICBiwgBAMwQIQCAGSIEADBDhAAAZogQAMAMEQIAmCFCAAAzg6wXcKWLFy/q5MmTCoVCCgQC1ssBAHhyzqm3t1dFRUW67bbrX+vcchE6efKkiouLrZcBAPic2tvbNWrUqOtuc8tFKBQKSZKm6ZsapMHGqwEA+OrXee3VjsSf59eTsQi98MIL+slPfqKOjg6NGzdOa9eu1QMPPHDDucv/BDdIgzUoQIQAIOv8/zuSfpa3VDLywYStW7dq6dKlWr58uVpaWvTAAw+oqqpKJ06cyMTuAABZKiMRWrNmjb773e/qe9/7nu655x6tXbtWxcXFWr9+fSZ2BwDIUmmP0Llz53Tw4EFVVlYmPV9ZWal9+/ZdtX08HlcsFkt6AAAGhrRH6NSpU7pw4YIKCwuTni8sLFRnZ+dV29fX1yscDicefDIOAAaOjH2z6pVvSDnnrvkm1bJly9TT05N4tLe3Z2pJAIBbTNo/HTdixAjl5eVdddXT1dV11dWRJAWDQQWDwXQvAwCQBdJ+JTRkyBBNnDhRDQ0NSc83NDSovLw83bsDAGSxjHyfUG1trb7zne/ovvvu09SpU/Xiiy/qxIkTevLJJzOxOwBAlspIhBYsWKDu7m4999xz6ujoUFlZmXbs2KGSkpJM7A4AkKUCzjlnvYg/F4vFFA6HVaGHuGMCAGShfndejXpDPT09ys/Pv+62/CgHAIAZIgQAMEOEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZogQAMBMRu6iDZi7LS+lsZPPTPae6Rv3iffMmOrfec8AuYgrIQCAGSIEADBDhAAAZogQAMAMEQIAmCFCAAAzRAgAYIYIAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJjhLtrISYGJX01p7sAPfuo985eH/2NK+wLAlRAAwBARAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMEQIAmCFCAAAzRAgAYIYbmCIn/dXLe27avk4cKfSeGaPj6V8IkIW4EgIAmCFCAAAzRAgAYIYIAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAM0QIAGCGCAEAzHADU9xUgWDQe+Z/vzjOe+Zvwi94z0jS3W8t9p75yg8Pes847wkgN3ElBAAwQ4QAAGbSHqG6ujoFAoGkRyQSSfduAAA5ICPvCY0bN07vvPNO4uu8vLxM7AYAkOUyEqFBgwZx9QMAuKGMvCfU2tqqoqIilZaW6tFHH9WxY8c+ddt4PK5YLJb0AAAMDGmP0OTJk7Vp0ybt3LlTL730kjo7O1VeXq7u7u5rbl9fX69wOJx4FBcXp3tJAIBbVNojVFVVpUceeUTjx4/XN77xDW3fvl2StHHjxmtuv2zZMvX09CQe7e3t6V4SAOAWlfFvVh02bJjGjx+v1tbWa74eDAYVTOEbGAEA2S/j3ycUj8f14YcfKhqNZnpXAIAsk/YIPfPMM2pqalJbW5vee+89fetb31IsFlN1dXW6dwUAyHJp/+e4P/7xj3rsscd06tQpjRw5UlOmTNH+/ftVUlKS7l0BALJc2iO0ZcuWdP+SyCHdj9/rPfPhN37mPTNpTa33
jCSNXb3Pe4abkQKp495xAAAzRAgAYIYIAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAM0QIAGCGCAEAzBAhAICZjP9QO+DP9c877T3z+pkC75lRG494z0jShZSmkKpBo+7yH3L+t4zt/9NJ//3gpuBKCABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMEQIAmCFCAAAzRAgAYIYIAQDMECEAgBkiBAAwQ4QAAGa4izZS1vPXU7xn3vj6f/We+cu/f9Z7ZvSpfd4z+Dd5+fneM0fqvuo98w8P/9R7JhWLjzyW0tywOcfSvBJciSshAIAZIgQAMEOEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMNzBFyk59M+49s+r//gfvmdHPcTPSVH3y4P0pzf305z/3nrlnyP9MYU835+/Bvxn3Skpz8+Yu9Z65/Tf/lNK+BiquhAAAZogQAMAMEQIAmCFCAAAzRAgAYIYIAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAM9zAFClrmOZ/k8uHWhZ5z0T1ofdMLsr7Qth75uv/6Xcp7eueIf5/P53W8tfeM8ENd3rPfOu5nd4zfxNO7RzqH8rf0zONIwwAMEOEAABmvCO0Z88ezZ07V0VFRQoEAtq2bVvS68451dXVqaioSEOHDlVFRYUOHz6crvUCAHKId4T6+vo0YcIErVu37pqvr1q1SmvWrNG6devU3NysSCSi2bNnq7e393MvFgCQW7w/mFBVVaWqqqprvuac09q1a7V8+XLNnz9fkrRx40YVFhZq8+bNeuKJJz7fagEAOSWt7wm1tbWps7NTlZWVieeCwaBmzJihffuu/SOa4/G4YrFY0gMAMDCkNUKdnZ2SpMLCwqTnCwsLE69dqb6+XuFwOPEoLi5O55IAALewjHw6LhAIJH3tnLvqucuWLVumnp6exKO9vT0TSwIA3ILS+s2qkUhE0qUromg0mni+q6vrqqujy4LBoILBYDqXAQDIEmm9EiotLVUkElFDQ0PiuXPnzqmpqUnl5eXp3BUAIAd4XwmdOXNGR48eTXzd1tam999/X8OHD9fo0aO1dOlSrVy5UmPGjNGYMWO0cuVK3XHHHXr88cfTunAAQPbzjtCBAwc0c+bMxNe1tbWSpOrqar388st69tlndfbsWT311FM6ffq0Jk+erLfffluhUCh9qwYA5ISAc85ZL+LPxWIxhcNhVeghDQoMtl7OgODKJ6Q0V7tpi/fMzyb5/7PshdOnvWdyUf87o71ndtzzDynta+rBb3vPjJx39MYbXeniBe+Rrxzw/3PhByN3e89I0lMl01KaG+j63Xk16g319PQoPz//utty7zgAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAM0QIAGCGCAEAzBAhAIAZIgQAMEOEAABmiBAAwAwRAgCYSetPVkV26vjb8ynN/en8nd4z3BH7krPz7vee2fPVF71n/rbTfz+SFP1+r/dMfwp3xE7FjiNl3jM/TPEu2sg8roQAAGaIEADADBECAJghQgAAM0QIAGCGCAEAzBAhAIAZIgQAMEOEAABmiBAAwAwRAgCYIUIAADPcwBT69pebU5pbeaDKe+bLaklpX7km8sz/8Z457/xvELpn7RTvGUn6wp/eTWnuZvjmV/7Fe2Zey6KU9hXRhynN4bPjSggAYIYIAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAM0QIAGCGCAEAzBAhAIAZIgQAMMMNTJGyiaUnvGd6MrCObPTQyPe9Z94+O8x75ovbj3jPSJL/rVJTc3Ha17xnnilY5z3zj/vu9Z6RpEhKU/DBlRAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMEQIAmCFCAAAzRAgAYIYbmCJlLXvHes/8hd7NwEqM3T/ee2R8cL/3zDtnvuo9c6H7X71nUpV3553e
M+W/+CfvmcK8oPfM3T/r8p6Rbt6NXAcyroQAAGaIEADAjHeE9uzZo7lz56qoqEiBQEDbtm1Len3hwoUKBAJJjylTpqRrvQCAHOIdob6+Pk2YMEHr1n36D5aaM2eOOjo6Eo8dO3Z8rkUCAHKT9wcTqqqqVFVVdd1tgsGgIhF+JiEA4Poy8p5QY2OjCgoKNHbsWC1atEhdXZ/+yZR4PK5YLJb0AAAMDGmPUFVVlV599VXt2rVLq1evVnNzs2bNmqV4PH7N7evr6xUOhxOP4uLidC8JAHCLSvv3CS1YsCDx32VlZbrvvvtUUlKi7du3a/78+Vdtv2zZMtXW1ia+jsVihAgABoiMf7NqNBpVSUmJWltbr/l6MBhUMOj/zWcAgOyX8e8T6u7uVnt7u6LRaKZ3BQDIMt5XQmfOnNHRo0cTX7e1ten999/X8OHDNXz4cNXV1emRRx5RNBrV8ePH9eMf/1gjRozQww8/nNaFAwCyn3eEDhw4oJkzZya+vvx+TnV1tdavX69Dhw5p06ZN+uijjxSNRjVz5kxt3bpVoVAofasGAOQE7whVVFTIOfepr+/cufNzLQjZY1rFv3jPnMzAOqydvuffec98ZXCe98yC/zHHe2a09nnPpOros3d7z7w+4m3vmRn//Jj3zJ0dHd4zuDm4dxwAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMEQIAmCFCAAAzRAgAYIYIAQDMZPwnqyJ3VY/8X94z9fr3GVgJ0u30wqneM7/7zt+lsCf/u4kP/ekXvGcu9h298UYwwZUQAMAMEQIAmCFCAAAzRAgAYIYIAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAM0QIAGCGG5hCmzbPTmmupuaQ90zrzyd7z4xZ8p73TC4a8vXT3jOBSeNT2lftsi3eM8f6/fezZHGN98zQxn/2nnHeE7hZuBICAJghQgAAM0QIAGCGCAEAzBAhAIAZIgQAMEOEAABmiBAAwAwRAgCYIUIAADNECABghggBAMxwA1NoVP2+lOYenPmo98zLf/XfvGe+N6Lae+bLi9u9ZyTpQve/es988sVASvvydWDSK94zF7ddzMBKrq3sv//Qe+Yv/vFd7xluRppbuBICAJghQgAAM0QIAGCGCAEAzBAhAIAZIgQAMEOEAABmiBAAwAwRAgCYIUIAADNECABghggBAMxwA1Ok7KPfFHnP9P7gdu+ZQ9P/3nvmqe0zvWckqbf/i94z/+Wuv0thT7n397/QcesVIBvl3v8TAABZgwgBAMx4Rai+vl6TJk1SKBRSQUGB5s2bpyNHjiRt45xTXV2dioqKNHToUFVUVOjw4cNpXTQAIDd4RaipqUk1NTXav3+/Ghoa1N/fr8rKSvX19SW2WbVqldasWaN169apublZkUhEs2fPVm9vb9oXDwDIbl4fTHjrrbeSvt6wYYMKCgp08OBBTZ8+Xc45rV27VsuXL9f8+fMlSRs3blRhYaE2b96sJ554In0rBwBkvc/1nlBPT48kafjw4ZKktrY2dXZ2qrKyMrFNMBjUjBkztG/ftX+EdDweVywWS3oAAAaGlCPknFNtba2mTZumsrIySVJnZ6ckqbCwMGnbwsLCxGtXqq+vVzgcTjyKi4tTXRIAIMukHKHFixfrgw8+0K9//eurXgsEAklfO+eueu6yZcuWqaenJ/Fob29PdUkAgCyT0jerLlmyRG+++ab27NmjUaNGJZ6PRCKSLl0RRaPRxPNdXV1XXR1dFgwGFQwGU1kGACDLeV0JOee0ePFivfbaa9q1a5dKS0uTXi8tLVUkElFDQ0PiuXPnzqmpqUnl5eXpWTEAIGd4XQnV1NRo8+bNeuONNxQKhRLv84TDYQ0dOlSBQEBLly7VypUrNWbMGI0ZM0YrV67UHXfcoccffzwjvwEAQPbyitD69eslSRUVFUnPb9iwQQsXLpQkPfvsszp79qyeeuopnT59WpMnT9bb
b7+tUCiUlgUDAHJHwDnnrBfx52KxmMLhsCr0kAYFBlsvB2l2cdrXvGeOLvR/67Kpcq33jCQV5vm/P7m1N3rjja6wYu8875lgftx75j9/7Q3vGUmqe+nb3jNFq9/z39HFC/4zuOX1u/Nq1Bvq6elRfn7+dbfl3nEAADNECABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMEQIAmCFCAAAzRAgAYIYIAQDMECEAgBkiBAAww120AQBpxV20AQBZgQgBAMwQIQCAGSIEADBDhAAAZogQAMAMEQIAmCFCAAAzRAgAYIYIAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAM0QIAGCGCAEAzBAhAIAZIgQAMEOEAABmiBAAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMEQIAmCFCAAAzRAgAYIYIAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAM0QIAGDGK0L19fWaNGmSQqGQCgoKNG/ePB05ciRpm4ULFyoQCCQ9pkyZktZFAwByg1eEmpqaVFNTo/3796uhoUH9/f2qrKxUX19f0nZz5sxRR0dH4rFjx460LhoAkBsG+Wz81ltvJX29YcMGFRQU6ODBg5o+fXri+WAwqEgkkp4VAgBy1ud6T6inp0eSNHz48KTnGxsbVVBQoLFjx2rRokXq6ur61F8jHo8rFoslPQAAA0PKEXLOqba2VtOmTVNZWVni+aqqKr366qvatWuXVq9erebmZs2aNUvxePyav059fb3C4XDiUVxcnOqSAABZJuCcc6kM1tTUaPv27dq7d69GjRr1qdt1dHSopKREW7Zs0fz58696PR6PJwUqFoupuLhYFXpIgwKDU1kaAMBQvzuvRr2hnp4e5efnX3dbr/eELluyZInefPNN7dmz57oBkqRoNKqSkhK1trZe8/VgMKhgMJjKMgAAWc4rQs45LVmyRK+//roaGxtVWlp6w5nu7m61t7crGo2mvEgAQG7yek+opqZGr7zyijZv3qxQKKTOzk51dnbq7NmzkqQzZ87omWee0bvvvqvjx4+rsbFRc+fO1YgRI/Twww9n5DcAAMheXldC69evlyRVVFQkPb9hwwYtXLhQeXl5OnTokDZt2qSPPvpI0WhUM2fO1NatWxUKhdK2aABAbvD+57jrGTp0qHbu3Pm5FgQAGDi4dxwAwAwRAgCYIUIAADNECABghggBAMwQIQCAGSIEADBDhAAAZogQAMAMEQIAmCFCAAAzRAgAYIYIAQDMECEAgBkiBAAwQ4QAAGaIEADADBECAJghQgAAM0QIAGCGCAEAzBAhAIAZIgQAMEOEAABmiBAAwMwg6wVcyTknSerXeckZLwYA4K1f5yX925/n13PLRai3t1eStFc7jFcCAPg8ent7FQ6Hr7tNwH2WVN1EFy9e1MmTJxUKhRQIBJJei8ViKi4uVnt7u/Lz841WaI/jcAnH4RKOwyUch0tuhePgnFNvb6+Kiop0223Xf9fnlrsSuu222zRq1KjrbpOfnz+gT7LLOA6XcBwu4ThcwnG4xPo43OgK6DI+mAAAMEOEAABmsipCwWBQK1asUDAYtF6KKY7DJRyHSzgOl3AcLsm243DLfTABADBwZNWVEAAgtxAhAIAZIgQAMEOEAABmsipCL7zwgkpLS3X77bdr4sSJ+u1vf2u9pJuqrq5OgUAg6RGJRKyXlXF79uzR3LlzVVRUpEAgoG3btiW97pxTXV2dioqKNHToUFVUVOjw4cM2i82gGx2HhQsXXnV+TJkyxWaxGVJfX69JkyYpFAqpoKBA8+bN05EjR5K2GQjnw2c5DtlyPmRNhLZu3aqlS5dq+fLlamlp0QMPPKCqqiqdOHHCemk31bhx49TR0ZF4HDp0yHpJGdfX16cJEyZo3bp113x91apVWrNmjdatW6fm5mZFIhHNnj07cR/CXHGj4yBJc+bMSTo/duzIrXswNjU1qaam
Rvv371dDQ4P6+/tVWVmpvr6+xDYD4Xz4LMdBypLzwWWJ+++/3z355JNJz919993uRz/6kdGKbr4VK1a4CRMmWC/DlCT3+uuvJ76+ePGii0Qi7vnnn08898knn7hwOOx++ctfGqzw5rjyODjnXHV1tXvooYdM1mOlq6vLSXJNTU3OuYF7Plx5HJzLnvMhK66Ezp07p4MHD6qysjLp+crKSu3bt89oVTZaW1tVVFSk0tJSPfroozp27Jj1kky1tbWps7Mz6dwIBoOaMWPGgDs3JKmxsVEFBQUaO3asFi1apK6uLuslZVRPT48kafjw4ZIG7vlw5XG4LBvOh6yI0KlTp3ThwgUVFhYmPV9YWKjOzk6jVd18kydP1qZNm7Rz50699NJL6uzsVHl5ubq7u62XZuby//4D/dyQpKqqKr366qvatWuXVq9erebmZs2aNUvxeNx6aRnhnFNtba2mTZumsrIySQPzfLjWcZCy53y45e6ifT1X/mgH59xVz+WyqqqqxH+PHz9eU6dO1Ze+9CVt3LhRtbW1hiuzN9DPDUlasGBB4r/Lysp03333qaSkRNu3b9f8+fMNV5YZixcv1gcffKC9e/de9dpAOh8+7Thky/mQFVdCI0aMUF5e3lV/k+nq6rrqbzwDybBhwzR+/Hi1trZaL8XM5U8Hcm5cLRqNqqSkJCfPjyVLlujNN9/U7t27k370y0A7Hz7tOFzLrXo+ZEWEhgwZookTJ6qhoSHp+YaGBpWXlxutyl48HteHH36oaDRqvRQzpaWlikQiSefGuXPn1NTUNKDPDUnq7u5We3t7Tp0fzjktXrxYr732mnbt2qXS0tKk1wfK+XCj43Att+z5YPihCC9btmxxgwcPdr/61a/c73//e7d06VI3bNgwd/z4ceul3TRPP/20a2xsdMeOHXP79+93Dz74oAuFQjl/DHp7e11LS4traWlxktyaNWtcS0uL+8Mf/uCcc+7555934XDYvfbaa+7QoUPusccec9Fo1MViMeOVp9f1jkNvb697+umn3b59+1xbW5vbvXu3mzp1qrvrrrty6jh8//vfd+Fw2DU2NrqOjo7E4+OPP05sMxDOhxsdh2w6H7ImQs4594tf/MKVlJS4IUOGuHvvvTfp44gDwYIFC1w0GnWDBw92RUVFbv78+e7w4cPWy8q43bt3O0lXPaqrq51zlz6Wu2LFCheJRFwwGHTTp093hw4dsl10BlzvOHz88ceusrLSjRw50g0ePNiNHj3aVVdXuxMnTlgvO62u9fuX5DZs2JDYZiCcDzc6Dtl0PvCjHAAAZrLiPSEAQG4iQgAAM0QIAGCGCAEAzBAhAIAZIgQAMEOEAABmiBAAwAwRAgCYIUIAADNECABghggBAMz8P2aC9402oNe/AAAAAElFTkSuQmCC"
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "6\n"
     ]
    }
   ],
   "source": [
    "trans_digit_train = datasets.MNIST(data_path,download=False,train=True,transform=transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize((0.1307), (0.3081))\n",
    "]))\n",
    "trans_digit_test = datasets.MNIST(data_path,download=False,train=False,transform=transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize((0.1325),(0.3105))\n",
    "]))\n",
    "\n",
    "digital,label = trans_digit_test[66]\n",
    "import matplotlib.pyplot as plt\n",
    "plt.imshow(digital.permute(1,2,0))\n",
    "plt.show()\n",
    "print(label)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-10-20T12:11:10.559311600Z",
     "start_time": "2023-10-20T12:11:07.554510800Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "class NetDigital(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.conv1 = nn.Conv2d(1,6,kernel_size=5)\n",
    "        self.conv2 = nn.Conv2d(6,16,kernel_size=5)\n",
    "        #self.conv3 = nn.Conv2d(32,32,kernel_size=3)\n",
    "        self.fc1 = nn.Linear(16*4*4,120)\n",
    "        self.fc2 = nn.Linear(120,84)\n",
    "        self.fc3 = nn.Linear(84,10)\n",
    "\n",
    "    def forward(self,x):\n",
    "        out = F.max_pool2d(torch.relu(self.conv1(x)),2)\n",
    "        out = F.max_pool2d(torch.relu(self.conv2(out)),2)\n",
    "        #out = torch.tanh(self.conv3(out))\n",
    "        out = out.view(-1,16*4*4)\n",
    "        out = torch.relu(self.fc1(out))\n",
    "        out = torch.relu(self.fc2(out))\n",
    "        out = self.fc3(out)\n",
    "        return out"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-10-20T12:37:21.381781500Z",
     "start_time": "2023-10-20T12:37:21.374783700Z"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "outputs": [],
   "source": [
    "import datetime\n",
    "import torch.optim as optim\n",
    "def train_loop(n_epochs,model,optimizer,loss_fn,train_loader):\n",
    "    for epoch in range(n_epochs):\n",
    "        loss_train = 0.0\n",
    "        for i,data in enumerate(train_loader,0):\n",
    "            imgs,labels=data\n",
    "            imgs,labels = imgs.cuda(),labels.cuda()\n",
    "            outputs = model(imgs)\n",
    "            loss = loss_fn(outputs,labels)\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            loss_train+=loss.item()\n",
    "            if  i%1000==999:\n",
    "                print('Epoch:{}, Bacth:{}, 训练损失:{}'.format(epoch+1,i+1,loss_train/1000))\n",
    "                loss_train=0.0\n"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-10-20T12:39:27.858707600Z",
     "start_time": "2023-10-20T12:39:27.848681400Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch:1, Bacth:1000, 训练损失:1.2367929233592003\n",
      "Epoch:1, Bacth:2000, 训练损失:0.303289304717342\n",
      "Epoch:1, Bacth:3000, 训练损失:0.22574096757614462\n",
      "Epoch:1, Bacth:4000, 训练损失:0.17842700932340813\n",
      "Epoch:1, Bacth:5000, 训练损失:0.13498068395414156\n",
      "Epoch:1, Bacth:6000, 训练损失:0.12090972116337252\n",
      "Epoch:1, Bacth:7000, 训练损失:0.1098496678590509\n",
      "Epoch:1, Bacth:8000, 训练损失:0.11663113340382915\n",
      "Epoch:1, Bacth:9000, 训练损失:0.09422701711452829\n",
      "Epoch:1, Bacth:10000, 训练损失:0.10355910789245172\n",
      "Epoch:1, Bacth:11000, 训练损失:0.09300098307246663\n",
      "Epoch:1, Bacth:12000, 训练损失:0.10155656610132247\n",
      "Epoch:1, Bacth:13000, 训练损失:0.08832439306768311\n",
      "Epoch:1, Bacth:14000, 训练损失:0.08969156544678117\n",
      "Epoch:1, Bacth:15000, 训练损失:0.05660615875464464\n",
      "Epoch:2, Bacth:1000, 训练损失:0.07857162088708355\n",
      "Epoch:2, Bacth:2000, 训练损失:0.07831226753018995\n",
      "Epoch:2, Bacth:3000, 训练损失:0.07384125157734252\n",
      "Epoch:2, Bacth:4000, 训练损失:0.06585274444246625\n",
      "Epoch:2, Bacth:5000, 训练损失:0.06073078918642841\n",
      "Epoch:2, Bacth:6000, 训练损失:0.0514165165160141\n",
      "Epoch:2, Bacth:7000, 训练损失:0.06151596793004046\n",
      "Epoch:2, Bacth:8000, 训练损失:0.06539852184813617\n",
      "Epoch:2, Bacth:9000, 训练损失:0.05425424248529453\n",
      "Epoch:2, Bacth:10000, 训练损失:0.06707932643979211\n",
      "Epoch:2, Bacth:11000, 训练损失:0.05608123748461003\n",
      "Epoch:2, Bacth:12000, 训练损失:0.06138814045599816\n",
      "Epoch:2, Bacth:13000, 训练损失:0.05246066661500538\n",
      "Epoch:2, Bacth:14000, 训练损失:0.057402441345757324\n",
      "Epoch:2, Bacth:15000, 训练损失:0.038757285816098115\n"
     ]
    }
   ],
   "source": [
    "model = NetDigital().cuda()\n",
    "train_loader = torch.utils.data.DataLoader(trans_digit_train,shuffle=False,batch_size=4,num_workers=2)\n",
    "optimizer = optim.SGD(model.parameters(),lr=0.01)\n",
    "loss_fn = nn.CrossEntropyLoss()\n",
    "train_loop(n_epochs=2,model=model,loss_fn=loss_fn,optimizer=optimizer,train_loader=train_loader)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-10-20T12:47:16.858281100Z",
     "start_time": "2023-10-20T12:45:24.976288900Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "outputs": [],
   "source": [
    "def test_loop(model,test_loader):\n",
    "    correct= 0\n",
    "    total = 0\n",
    "    with torch.no_grad():\n",
    "        for imgs,labels in test_loader:\n",
    "            imgs,labels = imgs.cuda(),labels.cuda()\n",
    "            outputs = model(imgs)\n",
    "            _,preds = torch.max(outputs,dim=1)\n",
    "            total +=labels.shape[0]\n",
    "            correct+=int((preds==labels).sum())\n",
    "    print('测试集精度:{:.3f}%'.format(100*correct/total))"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-10-20T12:52:25.752374500Z",
     "start_time": "2023-10-20T12:52:25.741393600Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "测试集精度:98.500%\n"
     ]
    }
   ],
   "source": [
    "test_loader = torch.utils.data.DataLoader(trans_digit_test,shuffle=False,batch_size=4)\n",
    "test_loop(model,test_loader)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-10-20T12:52:34.302435400Z",
     "start_time": "2023-10-20T12:52:26.294423800Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "outputs": [
    {
     "data": {
      "text/plain": "6"
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "output = model(digital.cuda().unsqueeze(0))\n",
    "_,pred = torch.max(output,dim=1)\n",
    "pred.item()"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-10-20T12:59:06.762437700Z",
     "start_time": "2023-10-20T12:59:06.750411900Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [],
   "metadata": {
    "collapsed": false
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
